{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Generative Adversarial Networks\n",
    "\n",
    "In this notebook, we play with the GAN described in the [course](https://mlelarge.github.io/dataflowr/Slides/GAN/index.html) on a double moon dataset.\n",
    "\n",
    "Then we implement a Conditional GAN and an InfoGAN."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# all of these libraries are used for plotting\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "x_min, x_max = -1.4, 2.8\n",
    "y_min, y_max = -4, 2.8\n",
    "\n",
    "# Plot the dataset\n",
    "def plot_data(ax, X, Y, color = 'bone'):\n",
    "    \"\"\"Scatter the 2D points `X` on `ax`, coloured by `Y` through the colormap `color`.\n",
    "\n",
    "    The axes frame is hidden and the view is set to the module-level\n",
    "    (x_min, x_max, y_min, y_max) window.\n",
    "    \"\"\"\n",
    "    # Configure the axes we were handed rather than pyplot's implicit current\n",
    "    # axes, so the function also does the right thing on multi-panel figures.\n",
    "    ax.axis('off')\n",
    "    ax.scatter(X[:, 0], X[:, 1], s=1, c=Y, cmap=color)\n",
    "    ax.axis([x_min, x_max, y_min, y_max])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "# load the data and visualize it\n",
    "data = np.load('./double_moon.npz')\n",
    "X = data['X']\n",
    "n_samples = X.shape[0]\n",
    "Y = np.ones(n_samples)\n",
    "fig, ax = plt.subplots(1, 1, facecolor='#4B6EA9')\n",
    "ax.set_xlim(x_min, x_max)\n",
    "ax.set_ylim(y_min, y_max)\n",
    "plot_data(ax, X, Y)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "use_gpu = torch.cuda.is_available()\n",
    "def gpu(tensor, gpu=use_gpu):\n",
    "    if gpu:\n",
    "        return tensor.cuda()\n",
    "    else:\n",
    "        return tensor"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# A simple GAN\n",
    "\n",
    "We start with the simple GAN described in the course."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "\n",
    "z_dim = 32        # size of the noise vector fed to the generator\n",
    "hidden_dim = 128  # width of the single hidden layer of both networks\n",
    "\n",
    "# Generator: noise vector -> 2D point.\n",
    "net_G = gpu(nn.Sequential(\n",
    "    nn.Linear(z_dim, hidden_dim),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(hidden_dim, 2),\n",
    "))\n",
    "\n",
    "# Discriminator: 2D point -> sigmoid score in (0, 1).\n",
    "net_D = gpu(nn.Sequential(\n",
    "    nn.Linear(2, hidden_dim),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(hidden_dim, 1),\n",
    "    nn.Sigmoid(),\n",
    "))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Training loop as described in the course, keeping the losses for the discriminator and the generator."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 50\n",
    "lr = 1e-4\n",
    "nb_epochs = 1000\n",
    "\n",
    "optimizer_G = torch.optim.Adam(net_G.parameters(), lr=lr)\n",
    "optimizer_D = torch.optim.Adam(net_D.parameters(), lr=lr)\n",
    "\n",
    "# One entry per epoch: the sum of the mini-batch losses, kept as plain Python\n",
    "# floats so the lists can be plotted directly and hold no autograd history.\n",
    "loss_D_epoch = []\n",
    "loss_G_epoch = []\n",
    "for e in range(nb_epochs):\n",
    "    # New sample order every epoch (shuffles the rows of X in place).\n",
    "    np.random.shuffle(X)\n",
    "    real_samples = torch.from_numpy(X).type(torch.FloatTensor)\n",
    "    loss_G = 0.0\n",
    "    loss_D = 0.0\n",
    "    for real_batch in real_samples.split(batch_size):\n",
    "        # split() may return a shorter final chunk; draw as many fake points as\n",
    "        # real ones so that both score tensors have the same shape.\n",
    "        n_batch = real_batch.size(0)\n",
    "\n",
    "        # improving D\n",
    "        z = gpu(torch.empty(n_batch, z_dim).normal_())\n",
    "        # detach(): this step only updates D, so do not back-propagate into G\n",
    "        fake_batch = net_G(z).detach()\n",
    "        D_scores_on_real = net_D(gpu(real_batch))\n",
    "        D_scores_on_fake = net_D(fake_batch)\n",
    "\n",
    "        loss = -torch.mean(torch.log(1-D_scores_on_fake) + torch.log(D_scores_on_real))\n",
    "        optimizer_D.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer_D.step()\n",
    "        loss_D += loss.item()\n",
    "\n",
    "        # improving G\n",
    "        z = gpu(torch.empty(batch_size, z_dim).normal_())\n",
    "        fake_batch = net_G(z)\n",
    "        D_scores_on_fake = net_D(fake_batch)\n",
    "\n",
    "        loss = -torch.mean(torch.log(D_scores_on_fake))\n",
    "        optimizer_G.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer_G.step()\n",
    "        loss_G += loss.item()\n",
    "\n",
    "    loss_D_epoch.append(loss_D)\n",
    "    loss_G_epoch.append(loss_G)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# float() converts one-element tensors as well as numbers, so the curves can be\n",
    "# drawn whether the training loop stored tensors or plain floats.\n",
    "plt.plot([float(v) for v in loss_D_epoch], label='discriminator')\n",
    "plt.plot([float(v) for v in loss_G_epoch], label='generator')\n",
    "plt.xlabel('epoch')\n",
    "plt.ylabel('loss (summed over mini-batches)')\n",
    "plt.legend();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sampling only: no gradients are needed, so skip building the autograd graph.\n",
    "with torch.no_grad():\n",
    "    z = gpu(torch.empty(n_samples, z_dim).normal_())\n",
    "    fake_samples = net_G(z)\n",
    "fake_data = fake_samples.cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack real points (label 1) and generated points (label 0) into one scatter.\n",
    "all_data = np.concatenate([X, fake_data], axis=0)\n",
    "Y2 = np.concatenate([np.ones(n_samples), np.zeros(n_samples)])\n",
    "\n",
    "fig, ax = plt.subplots(nrows=1, ncols=1, facecolor='#4B6EA9')\n",
    "ax.set_xlim((x_min, x_max))\n",
    "ax.set_ylim((y_min, y_max))\n",
    "plot_data(ax, all_data, Y2)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can generate more points:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ten generated points per real one; the count is named once so the noise batch\n",
    "# and the label vector below cannot get out of sync.\n",
    "n_fake = 10 * n_samples\n",
    "\n",
    "# Sampling only: no gradients are needed, so skip building the autograd graph.\n",
    "with torch.no_grad():\n",
    "    z = gpu(torch.empty(n_fake, z_dim).normal_())\n",
    "    fake_samples = net_G(z)\n",
    "fake_data = fake_samples.cpu().numpy()\n",
    "\n",
    "fig, ax = plt.subplots(1, 1, facecolor='#4B6EA9')\n",
    "all_data = np.concatenate((X, fake_data), axis=0)\n",
    "Y2 = np.concatenate((np.ones(n_samples), np.zeros(n_fake)))\n",
    "plot_data(ax, all_data, Y2)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Conditional GAN\n",
    "\n",
    "We now implement a [conditional GAN](https://arxiv.org/abs/1411.1784).\n",
    "\n",
    "We start by separating the two half moons into two clusters, as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch the points again, this time together with the labels stored under 'Y'.\n",
    "X, Y = data['X'], data['Y']\n",
    "n_samples = len(X)\n",
    "\n",
    "fig, ax = plt.subplots(nrows=1, ncols=1, facecolor='#4B6EA9')\n",
    "ax.set_xlim((x_min, x_max))\n",
    "ax.set_ylim((y_min, y_max))\n",
    "plot_data(ax, X, Y)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The task is now the following: given a white or black label, generate points in the corresponding cluster.\n",
    "\n",
    "Both the generator and the discriminator take a one-hot encoding of the label as an additional input. The generator now generates fake points corresponding to the input label. The discriminator, given a (sample, label) pair, should detect whether it is a fake or a real pair."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "z_dim = 32\n",
    "hidden_dim = 128\n",
    "label_dim = 2\n",
    "\n",
    "\n",
    "class generator(nn.Module):\n",
    "    def __init__(self,z_dim = z_dim, label_dim=label_dim,hidden_dim =hidden_dim):\n",
    "        super(generator,self).__init__()\n",
    "        self.net = nn.Sequential(nn.Linear(z_dim+label_dim,hidden_dim),\n",
    "                     nn.ReLU(), nn.Linear(hidden_dim, 2))\n",
    "        \n",
    "    def forward(self, input, label_onehot):\n",
    "        x = torch.cat([input, label_onehot], 1)\n",
    "        return self.net(x)\n",
    "    \n",
    "class discriminator(nn.Module):\n",
    "    \"\"\"Conditional discriminator: scores a (2D point, one-hot label) pair in (0, 1).\"\"\"\n",
    "\n",
    "    def __init__(self,z_dim = z_dim, label_dim=label_dim,hidden_dim =hidden_dim):\n",
    "        # z_dim is not used by the discriminator; it is part of the existing signature.\n",
    "        super().__init__()\n",
    "        layers = [\n",
    "            nn.Linear(2 + label_dim, hidden_dim),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(hidden_dim, 1),\n",
    "            nn.Sigmoid(),\n",
    "        ]\n",
    "        self.net = nn.Sequential(*layers)\n",
    "\n",
    "    def forward(self, input, label_onehot):\n",
    "        # The label is appended to the sample along dim 1 before scoring.\n",
    "        pair = torch.cat((input, label_onehot), dim=1)\n",
    "        return self.net(pair)\n",
    "        \n",
    "\n",
    "# Fresh conditional generator / discriminator pair, moved to CUDA when `gpu` does so.\n",
    "net_CG, net_CD = gpu(generator()), gpu(discriminator())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You need to code the training loop:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 50\n",
    "lr = 1e-4\n",
    "nb_epochs = 1000\n",
    "\n",
    "optimizer_CG = torch.optim.Adam(net_CG.parameters(), lr=lr)\n",
    "optimizer_CD = torch.optim.Adam(net_CD.parameters(), lr=lr)\n",
    "\n",
    "# One entry per epoch: the sum of the mini-batch losses, kept as plain Python\n",
    "# floats so the lists can be plotted directly and hold no autograd history.\n",
    "loss_D_epoch = []\n",
    "loss_G_epoch = []\n",
    "for e in range(nb_epochs):\n",
    "    # Apply the same random permutation to points and labels so pairs stay aligned.\n",
    "    rperm = np.random.permutation(X.shape[0])\n",
    "    X, Y = X[rperm], Y[rperm]\n",
    "    real_samples = torch.from_numpy(X).type(torch.FloatTensor)\n",
    "    real_labels = torch.from_numpy(Y).type(torch.LongTensor)\n",
    "    loss_G = 0.0\n",
    "    loss_D = 0.0\n",
    "    for real_batch, real_batch_label in zip(real_samples.split(batch_size), real_labels.split(batch_size)):\n",
    "        # split() may return a shorter final chunk; size the noise to the real\n",
    "        # batch so that real and fake scores can be combined element-wise.\n",
    "        n_batch = real_batch.size(0)\n",
    "\n",
    "        # improving D\n",
    "        z = gpu(torch.empty(n_batch, z_dim).normal_())\n",
    "        #\n",
    "        # your code here: it must define D_scores_on_real and D_scores_on_fake\n",
    "        # hint: https://discuss.pytorch.org/t/convert-int-into-one-hot-format/507/4\n",
    "        #\n",
    "\n",
    "        loss = -torch.mean(torch.log(1-D_scores_on_fake) + torch.log(D_scores_on_real))\n",
    "        optimizer_CD.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer_CD.step()\n",
    "        loss_D += loss.item()\n",
    "\n",
    "        # improving G\n",
    "        z = gpu(torch.empty(batch_size, z_dim).normal_())\n",
    "        #\n",
    "        # your code here: it must define D_scores_on_fake for freshly generated points\n",
    "        #\n",
    "\n",
    "        loss = -torch.mean(torch.log(D_scores_on_fake))\n",
    "        optimizer_CG.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer_CG.step()\n",
    "        loss_G += loss.item()\n",
    "\n",
    "    loss_D_epoch.append(loss_D)\n",
    "    loss_G_epoch.append(loss_G)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# float() converts one-element tensors as well as numbers, so the curves can be\n",
    "# drawn whether the training loop stored tensors or plain floats.\n",
    "plt.plot([float(v) for v in loss_D_epoch], label='discriminator')\n",
    "plt.plot([float(v) for v in loss_G_epoch], label='generator')\n",
    "plt.xlabel('epoch')\n",
    "plt.ylabel('loss (summed over mini-batches)')\n",
    "plt.legend();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw one noise vector and one random class per sample.  Device placement goes\n",
    "# through gpu() so the cell also runs on CPU-only machines.  `label` itself stays\n",
    "# on the CPU because the plotting cell converts it with .numpy().\n",
    "with torch.no_grad():  # sampling only, no gradients needed\n",
    "    z = gpu(torch.empty(n_samples, z_dim).normal_())\n",
    "    label = torch.LongTensor(n_samples, 1).random_() % label_dim\n",
    "    label_onehot = torch.FloatTensor(n_samples, label_dim).zero_()\n",
    "    label_onehot = gpu(label_onehot.scatter_(1, label, 1))\n",
    "    fake_samples = net_CG(z, label_onehot)\n",
    "fake_data = fake_samples.cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generated points are coloured by the label they were conditioned on (default\n",
    "# colormap of plot_data); the real data is drawn on top with the 'spring' colormap.\n",
    "fake_labels = label.squeeze().numpy()\n",
    "\n",
    "fig, ax = plt.subplots(nrows=1, ncols=1, facecolor='#4B6EA9')\n",
    "ax.set_xlim((x_min, x_max))\n",
    "ax.set_ylim((y_min, y_max))\n",
    "plot_data(ax, fake_data, fake_labels)\n",
    "plot_data(ax, X, Y, 'spring')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Info GAN\n",
    "\n",
    "Implement the [algorithm](https://arxiv.org/abs/1606.03657)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:py3]",
   "language": "python",
   "name": "conda-env-py3-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
